import numpy as np
import warnings
warnings.filterwarnings('ignore')
import pandas as pd
import os
import pdb
import time
#import tf.keras.backend as K
from tqdm import tqdm
import keras_metrics
from keras.layers import Input
from keras.preprocessing import sequence
from keras.layers import LSTM, Dense, Masking, Concatenate, concatenate, Bidirectional, BatchNormalization, Dropout
from keras.utils import to_categorical
from sklearn.preprocessing import LabelEncoder
from keras.models import Sequential, Model
from keras.utils import multi_gpu_model
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_sample_weight, compute_class_weight

###
#os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
#os.environ["CUDA_VISIBLE_DEVICES"]= "1"

# ---------------------------------------------------------------------------
# Load the raw training records.
#
# Every regular file in `data_path` is read line by line. Each non-empty line
# is split on ',' and interpreted as:
#   field 1, field 2 -> an (int, int) index pair (first character dropped;
#                       presumably the space that follows the comma)
#   field 3          -> XHMM prediction token (first character dropped)
#   field 4          -> CNVnator prediction token (first character and any
#                       ']' dropped)
#   field 5 onwards  -> read-depth values; '[', ']' and spaces are stripped,
#                       and an empty token is stored as 0
# ---------------------------------------------------------------------------
start = time.time()
data_path = "./training_data/"

data_list_readdepths = []
data_list_indexes = []
data_list_cnvnator_preds = []
data_list_xhmm_preds = []

files_list = os.listdir(data_path)

print("loading the training data...")
for filename in tqdm(files_list):
    file_path = os.path.join(data_path, filename)
    # os.listdir also returns sub-directories; open() would fail on those.
    if not os.path.isfile(file_path):
        continue

    with open(file_path) as f:
        for line in f:
            line = line.strip()
            # Blank lines (e.g. a trailing newline) carry no record.
            if not line:
                continue

            # Split once per record instead of once per extracted field.
            fields = line.split(',')

            data_list_indexes.append((int(fields[1][1:]), int(fields[2][1:])))
            data_list_xhmm_preds.append(fields[3][1:])
            data_list_cnvnator_preds.append(fields[4][1:].replace(']', ''))

            depths = []
            for token in fields[5:]:
                token = token.replace('[', '', 1).replace(']', '').replace(' ', '')
                depths.append(int(token) if token else 0)
            data_list_readdepths.append(depths)

end = time.time()
print("Loading of the data took ", end-start," seconds.")

# Keep only the records whose read-depth signal has fewer than 4000 values.
lens = np.asarray([len(k) for k in data_list_readdepths])
mask = lens < 4000

# The read-depth signals have different lengths. Converting that ragged list
# with np.asarray raises ValueError on NumPy >= 1.24 (and silently builds an
# object array on older versions), so the list is filtered directly and only
# turned into an array by pad_sequences below.
data_list_readdepths = [
    seq for seq, keep in zip(data_list_readdepths, mask) if keep
]

# The remaining per-record lists are rectangular and can be converted and
# filtered as NumPy arrays.
data_list_indexes = np.asarray(data_list_indexes)[mask]
data_list_cnvnator_preds = np.asarray(data_list_cnvnator_preds)[mask]
data_list_xhmm_preds = np.asarray(data_list_xhmm_preds)[mask]

# Pad every signal with -1 to a fixed width of 3999 steps, the longest length
# the filter above admits. Without `maxlen` the width would be the length of
# the longest signal that happens to be present, which need not be 3999.
data_list_readdepths = sequence.pad_sequences(
    data_list_readdepths, maxlen=3999, value=-1
)
#print(data_list_readdepths)


# Replace the raw prediction tokens with numeric class codes, in place.
#
#   CNVnator: nan      -> 0
#             '<DUP>'  -> 1
#             '<DEL>'  -> 2
#   XHMM:     'del'    -> 0
#             'dup'    -> 1
#
# The tokens being matched include their literal single quotes. The codes are
# written back into the existing arrays, which were built from strings; the
# code further down compares the CNVnator entries against '0', '1' and '2'.
for raw_label, code in (('nan', 0), ("'<DUP>'", 1), ("'<DEL>'", 2)):
    data_list_cnvnator_preds[data_list_cnvnator_preds == raw_label] = code

for raw_label, code in (("'del'", 0), ("'dup'", 1)):
    data_list_xhmm_preds[data_list_xhmm_preds == raw_label] = code


# Boolean masks over the encoded CNVnator labels.
nocallinds = data_list_cnvnator_preds == '0'
dupinds = data_list_cnvnator_preds == '1'
delinds = data_list_cnvnator_preds == '2'

# Randomly thin out the no-call records: a no-call is dropped whenever its
# uniform draw is below 0.89 and kept otherwise. Records that are not
# no-calls stay False here and consume no random draw.
nocallindssampled = np.asarray(
    [flag and not np.random.uniform() < 0.89 for flag in nocallinds]
)

# Adding boolean arrays is an element-wise OR: keep every DEL, every DUP and
# the surviving no-calls.
total_filter = nocallindssampled + delinds + dupinds

# Apply the filter, then cap each array at its first 70000 records.
data_list_xhmm_preds = data_list_xhmm_preds[total_filter][:70000]
data_list_cnvnator_preds = data_list_cnvnator_preds[total_filter][:70000]
data_list_readdepths = data_list_readdepths[total_filter][:70000]

# One-hot encode the class codes: 3 CNVnator classes (the labels) and
# 2 XHMM classes (an input feature).
data_list_cnvnator_preds = to_categorical(data_list_cnvnator_preds, num_classes=3)
data_list_xhmm_preds = to_categorical(data_list_xhmm_preds, num_classes=2)

# Append a trailing feature axis of size 1 to the padded read-depth matrix.
# The read depths are left unscaled (a division by 45000 was disabled).
data_list_readdepths = data_list_readdepths[:, :, np.newaxis]

# Report the final shapes of the three matrices.
for caption, matrix in (
    ("Read depths data matrix shape: ", data_list_readdepths),
    ("conifer predictions data matrix shape: ", data_list_xhmm_preds),
    ("Cnvnator predictions (labels) data matrix shape: ", data_list_cnvnator_preds),
):
    print(caption, matrix.shape)


# Data fed to the network:
#   input1 <- data_list_xhmm_preds      (one-hot XHMM prediction, 2 values)
#   input2 <- data_list_readdepths      (padded read-depth sequence)
#   labels <- data_list_cnvnator_preds  (one-hot CNVnator call, 3 classes)

# model
# Take the sequence length from the prepared array instead of hard-coding it:
# the padded width only equals 3999 when a signal of exactly that length
# survived the filtering, and any other width would make fit() reject the
# input shape.
max_length = data_list_readdepths.shape[1]  # length of the padded read depth signals
input1 = Input(shape=(2,))  # xhmm prediction
input2 = Input(shape=(max_length, 1))  # read depth sequence
# Time steps equal to the padding value (-1) are masked out.
masked_input2 = Masking(mask_value=-1)(input2)
features1 = BatchNormalization()(masked_input2)
features2 = Bidirectional(LSTM(128))(features1)
features3 = BatchNormalization()(features2)
# The sequence features are joined with the XHMM prediction before the
# dense classification head.
merged = concatenate([features3, input1])
features4 = Dense(100, activation='relu')(merged)
output = Dense(3, activation='softmax')(features4)

model = Model(inputs=[input1, input2], outputs=output)
# summary() prints the table itself and returns None, so wrapping it in
# print() only added a stray "None" line.
model.summary()


# Parallelize the model with keras.utils.multi_gpu_model (4 GPUs). The
# template model keeps its own name: the replica wrapper shares its weights
# and is used for training, while the template is what gets saved at the end.
parallel_model = multi_gpu_model(model, gpus=4)

# Hold out 10% of the records as a test set (fixed seed for repeatability).
(data_list_xhmm_preds_train, data_list_xhmm_preds_test,
 data_list_readdepths_train, data_list_readdepths_test,
 data_list_cnvnator_preds_train, data_list_cnvnator_preds_test) = train_test_split(
    data_list_xhmm_preds, data_list_readdepths, data_list_cnvnator_preds,
    test_size=0.1, random_state=35)

# np.save and model.save do not create missing directories.
os.makedirs('./outputs', exist_ok=True)

# Persist the held-out split.
np.save('./outputs/data_list_codex2_preds_test_weightedcrossent.npy', data_list_xhmm_preds_test)
np.save('./outputs/data_list_readdepths_test_weightedcrossent.npy', data_list_readdepths_test)
np.save('./outputs/data_list_cnvnator_preds_test_weightedcrossent.npy', data_list_cnvnator_preds_test)

# Balanced class weights computed from the training labels.
y_integers = np.argmax(data_list_cnvnator_preds_train, axis=1)
present_classes = np.unique(y_integers)
# Keyword arguments: recent scikit-learn releases reject the positional form.
class_weights = compute_class_weight(
    class_weight='balanced', classes=present_classes, y=y_integers)
# Key each weight by its actual class id. enumerate() would mislabel the
# weights whenever a class is absent from the training split.
d_class_weights = {
    int(cls): weight for cls, weight in zip(present_classes, class_weights)
}

parallel_model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy',
             keras_metrics.categorical_precision(),
             keras_metrics.categorical_recall()])
parallel_model.fit(
    [data_list_xhmm_preds_train, data_list_readdepths_train],
    data_list_cnvnator_preds_train,
    validation_split=0.2,
    epochs=30,
    batch_size=256,
    class_weight=d_class_weights)

# Per the Keras documentation, a multi-GPU model is saved through its
# template model (which holds the trained weights), not through the wrapper
# returned by multi_gpu_model.
# NOTE(review): the template is never compiled, so the file carries no
# optimizer/metric state - confirm no downstream script relies on that.
model.save('./outputs/deepXCNV_batchnorm_bilstm128_batchnorm_dense100_dense3_bs256_padding-1_30epochs_traintestsplitted_weightedcrossent_codex2.h5')


